Code
DataTransformerRegistry.enable('default')
DataTransformerRegistry.enable('default')
df = read_parquet_and_reorder("df.parquet")
logger.info(df.shape)

# Move the per-100 g nutrient columns (suffix "_100g") into their own frame,
# keyed by "code", and keep every other column in ``df``.
per_100g_columns = [name for name in df.columns if name.endswith("_100g")]
df_per_100g = df.select(["code", *per_100g_columns])
df = df.drop(per_100g_columns)

# Categorical columns that are each one-hot encoded into their own frame.
columns = [
    "categories_en",
    "ingredients_tags",
    "ingredients_analysis_tags",
    "traces_en",
    "food_groups_en",
    "nutrient_levels_tags",
    "main_category_en",
    "packaging_en",
]

# One frame per encoded column, plus the nutrient frame under "nutrients".
# ``one_hot_encode`` is defined elsewhere; ``DataFrame.pipe`` calls it as
# one_hot_encode(df, column, n=10, remove_prefix=[...]).
df_dict: dict[str, pl.DataFrame] = {
    **{
        name: df.pipe(one_hot_encode, name, n=10, remove_prefix=["en:", "de:"])
        for name in columns
    },
    "nutrients": df_per_100g,
}
# [06/30/23 09:47:11] INFO (73307, 175) 3510458283.py:4
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO (73307, 11) one_hot_encode.py:58
INFO categories_en 2366952518.py:3
INFO ingredients_tags 2366952518.py:3
INFO ingredients_analysis_tags 2366952518.py:3
INFO traces_en 2366952518.py:3
INFO food_groups_en 2366952518.py:3
INFO nutrient_levels_tags 2366952518.py:3
INFO main_category_en 2366952518.py:3
INFO packaging_en 2366952518.py:3
INFO nutrients 2366952518.py:3
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
import collections
import numpy as np
from sklearn.linear_model import LinearRegression
# Target: the Nutri-Score of each product, flattened to a 1-D array.
y = df.select("nutriscore_score").to_numpy().flatten()
logger.info(y.shape)

# Features: every column of ``df_for_ml`` except the product identifier.
# NOTE(review): ``df_for_ml`` is built outside this view. X and y are paired
# purely by row position, which assumes ``df_for_ml`` holds the same rows in
# the same order as ``df`` -- TODO confirm (a join on "code" can reorder,
# drop or duplicate rows). Also confirm it carries no column derived from the
# Nutri-Score itself, which would leak the target into the features.
X = df_for_ml.drop("code").to_numpy()
logger.info(X.shape)
if X.shape[0] != y.shape[0]:
    raise ValueError(
        f"feature rows ({X.shape[0]}) and target rows ({y.shape[0]}) differ; "
        "X and y cannot be paired by position"
    )

# No feature scaling: a decision tree splits on per-feature thresholds, so it
# is unaffected by per-column monotonic scaling. The previously used
# ``sklearn.preprocessing.Normalizer`` rescales each *row* to unit L2 norm,
# which couples unrelated features (large-magnitude columns dominate the norm)
# and discards absolute quantities instead of standardizing columns.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.10, random_state=2023
)

# Other models tried here: tree.DecisionTreeClassifier(max_depth=15) and
# LinearRegression(). ``random_state`` fixes the tree's random feature
# permutation at each split so repeated runs build the same tree.
clf = tree.DecisionTreeRegressor(random_state=2023)
clf = clf.fit(X_train, y_train)

# Predicted vs. actual score for both splits (test rows first, then train),
# plus the signed error ``predicted - actual``.
df_tree = pl.concat(
    [
        pl.DataFrame(
            {
                "actual score": y_actual,
                "predicted score": clf.predict(X_split),
                "label": label,
            }
        )
        for label, X_split, y_actual in (
            ("test", X_test, y_test),
            ("train", X_train, y_train),
        )
    ]
).with_columns(err=pl.col("predicted score") - pl.col("actual score"))
# [06/30/23 09:50:24] INFO (73307,) 3204777385.py:9
INFO (73307, 196) 3204777385.py:12
# Heatmap of predicted vs. actual score from ``df_tree``.
# - Both axes are binned in steps of 1 over the domain -15..40.
# - clip=True clips marks that fall outside that domain.
# - Colour is the row count per cell, using a reversed viridis scheme.
# - The chart is faceted into one panel per "label" value (test / train),
#   and each panel gets its own colour scale.
# NOTE(review): the trailing "| nutriscore_score | |" on the last line is the
# header row of the exported output table below, fused onto the code during
# export; it is not part of the chart expression.
alt.Chart(df_tree).mark_rect(clip=True).encode(
    x=alt.X("actual score:Q").bin(step=1).scale(domain=(-15, 40)),
    y=alt.Y("predicted score:Q").bin(step=1).scale(domain=(-15, 40)),
    color=alt.Color("count():Q").scale(scheme="viridis", reverse=True),
    column=alt.Column("label:N"),
).properties(width=500, height=500).resolve_scale(color="independent")| nutriscore_score | |
|---|---|
| saturated-fat-in-low-quantity | -0.52 |
| fat-in-low-quantity | -0.41 |
| Beverages | -0.39 |
| Plant-based foods and beverages | -0.39 |
| Plant-based foods | -0.39 |
| Cereals and potatoesfood_groups_en | -0.33 |
| Cereals and potatoes | -0.27 |
| sugars-in-low-quantity | -0.27 |
| Cereals and their products | -0.27 |
| vegan | -0.25 |
| Fruits and vegetables | -0.23 |
| energy-kcal_100g | 0.21 |
| Nuts | 0.23 |
| dairy | 0.24 |
| disaccharide | 0.24 |
| sugar | 0.25 |
| non-vegan | 0.27 |
| added-sugar | 0.27 |
| energy_100g | 0.34 |
| sugars_100g | 0.37 |
| sugars-in-high-quantity | 0.43 |
| Sugary snacks | 0.43 |
| Snacks | 0.44 |
| Sweet snacks | 0.45 |
| fat_100g | 0.49 |
| saturated-fat_100g | 0.58 |
| fat-in-high-quantity | 0.60 |
| saturated-fat-in-high-quantity | 0.67 |